Initialize PySpark

First, we use the findspark package to initialize PySpark.


In [1]:
# Initialize PySpark
APP_NAME = "Debugging Prediction Problems"

# If there is no SparkSession, create the environment
try:
  sc and spark
except NameError as e:
  import findspark
  findspark.init()
  import pyspark
  import pyspark.sql

  spark = pyspark.sql.SparkSession.builder.appName(APP_NAME).getOrCreate()
  sc = spark.sparkContext

print("PySpark initiated...")


PySpark initiated...

Hello, World!

Loading data, mapping it and collecting the records into RAM...


In [2]:
# Load the text file using the SparkContext
csv_lines = sc.textFile("../data/example.csv")

# Map the data to split the lines into a list
data = csv_lines.map(lambda line: line.split(","))

# Collect the dataset into local RAM
data.collect()


Out[2]:
[['Russell Jurney', 'Relato', 'CEO'],
 ['Florian Liebert', 'Mesosphere', 'CEO'],
 ['Don Brown', 'Rocana', 'CIO'],
 ['Steve Jobs', 'Apple', 'CEO'],
 ['Donald Trump', 'The Trump Organization', 'CEO'],
 ['Russell Jurney', 'Data Syndrome', 'Principal Consultant']]

Creating Objects from CSV

Using a function with a map operation to create objects (dicts) as records...


In [3]:
# Turn the CSV lines into objects
def csv_to_record(line):
  parts = line.split(",")
  record = {
    "name": parts[0],
    "company": parts[1],
    "title": parts[2]
  }
  return record

# Apply the function to every record
records = csv_lines.map(csv_to_record)

# Inspect the first item in the dataset
records.first()


Out[3]:
{'company': 'Relato', 'name': 'Russell Jurney', 'title': 'CEO'}

GroupBy

Using the groupBy operator to count the number of jobs per person...


In [4]:
# Group the records by the name of the person
grouped_records = records.groupBy(lambda x: x["name"])

# Show the first group
grouped_records.first()

# Count the groups
job_counts = grouped_records.map(
  lambda x: {
    "name": x[0],
    "job_count": len(x[1])
  }
)

job_counts.first()

job_counts.collect()


Out[4]:
[{'job_count': 1, 'name': 'Florian Liebert'},
 {'job_count': 1, 'name': 'Don Brown'},
 {'job_count': 2, 'name': 'Russell Jurney'},
 {'job_count': 1, 'name': 'Donald Trump'},
 {'job_count': 1, 'name': 'Steve Jobs'}]
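
The groupBy approach materializes every group in memory. For a simple count, mapping to (name, 1) pairs and reducing by key is a lighter-weight alternative; a minimal sketch using the records RDD from above (job_counts_alt is just an illustrative name):


In [ ]:
# Alternative sketch: count jobs per person with reduceByKey instead of groupBy
job_counts_alt = records\
  .map(lambda record: (record["name"], 1))\
  .reduceByKey(lambda a, b: a + b)\
  .map(lambda pair: {"name": pair[0], "job_count": pair[1]})

job_counts_alt.collect()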

Map vs FlatMap

Understanding the difference between the map and flatMap operators: map emits one element per input line (here, a list of words), while flatMap flattens those lists into a single relation of words...


In [5]:
# Compute a relation of words by line
words_by_line = csv_lines\
  .map(lambda line: line.split(","))

print(words_by_line.collect())

# Compute a relation of words
flattened_words = csv_lines\
  .map(lambda line: line.split(","))\
  .flatMap(lambda x: x)

flattened_words.collect()


[['Russell Jurney', 'Relato', 'CEO'], ['Florian Liebert', 'Mesosphere', 'CEO'], ['Don Brown', 'Rocana', 'CIO'], ['Steve Jobs', 'Apple', 'CEO'], ['Donald Trump', 'The Trump Organization', 'CEO'], ['Russell Jurney', 'Data Syndrome', 'Principal Consultant']]
Out[5]:
['Russell Jurney',
 'Relato',
 'CEO',
 'Florian Liebert',
 'Mesosphere',
 'CEO',
 'Don Brown',
 'Rocana',
 'CIO',
 'Steve Jobs',
 'Apple',
 'CEO',
 'Donald Trump',
 'The Trump Organization',
 'CEO',
 'Russell Jurney',
 'Data Syndrome',
 'Principal Consultant']

Creating Rows

Creating pyspark.sql.Rows out of your data so you can create DataFrames...


In [6]:
from pyspark.sql import Row

# Convert the CSV into a pyspark.sql.Row
def csv_to_row(line):
  parts = line.split(",")
  row = Row(
    name=parts[0],
    company=parts[1],
    title=parts[2]
  )
  return row

# Apply the function to get rows in an RDD
rows = csv_lines.map(csv_to_row)

Creating DataFrames from RDDs

Using the RDD.toDF() method to create a DataFrame, registering the DataFrame as a temporary table with Spark SQL, and counting the jobs per person using Spark SQL.


In [7]:
# Convert to a pyspark.sql.DataFrame
rows_df = rows.toDF()

# Register the DataFrame for Spark SQL
rows_df.registerTempTable("executives")

# Generate a new DataFrame with SQL using the SparkSession
job_counts = spark.sql("""
SELECT
  name,
  COUNT(*) AS total
FROM executives
GROUP BY name
""")
job_counts.show()

# Go back to an RDD
job_counts.rdd.collect()


+---------------+-----+
|           name|total|
+---------------+-----+
|   Donald Trump|    1|
|Florian Liebert|    1|
|      Don Brown|    1|
| Russell Jurney|    2|
|     Steve Jobs|    1|
+---------------+-----+

Out[7]:
[Row(name='Donald Trump', total=1),
 Row(name='Florian Liebert', total=1),
 Row(name='Don Brown', total=1),
 Row(name='Russell Jurney', total=2),
 Row(name='Steve Jobs', total=1)]
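
The same aggregation can be expressed with the DataFrame API instead of SQL; a minimal sketch against the rows_df DataFrame defined above:


In [ ]:
# The same job count using the DataFrame API rather than Spark SQL
rows_df.groupBy("name").count().show()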

Loading and Inspecting Parquet Files

Using the SparkSession to load files as DataFrames and inspecting their contents...


In [8]:
# Load the parquet file containing flight delay records
on_time_dataframe = spark.read.parquet('../data/on_time_performance.parquet')

# Register the data for Spark SQL
on_time_dataframe.registerTempTable("on_time_performance")

# Check out the columns
on_time_dataframe.columns

# Check out some data
on_time_dataframe\
  .select("FlightDate", "TailNum", "Origin", "Dest", "Carrier", "DepDelay", "ArrDelay")\
  .show()

# Trim the fields and keep the result
trimmed_on_time = on_time_dataframe\
  .select(
    "FlightDate",
    "TailNum",
    "Origin",
    "Dest",
    "Carrier",
    "DepDelay",
    "ArrDelay"
  )

# Sample 0.01% of the data and show
trimmed_on_time.sample(False, 0.0001).show()


+----------+-------+------+----+-------+--------+--------+
|FlightDate|TailNum|Origin|Dest|Carrier|DepDelay|ArrDelay|
+----------+-------+------+----+-------+--------+--------+
|2015-01-01| N001AA|   DFW| MEM|     AA|    -3.0|    -6.0|
|2015-01-01| N001AA|   MEM| DFW|     AA|    -4.0|    -9.0|
|2015-01-01| N002AA|   ORD| DFW|     AA|     0.0|    26.0|
|2015-01-01| N003AA|   DFW| ATL|     AA|   100.0|   112.0|
|2015-01-01| N003AA|   DFW| HDN|     AA|    78.0|    78.0|
|2015-01-01| N003AA|   HDN| DFW|     AA|   332.0|   336.0|
|2015-01-01| N004AA|   JAC| DFW|     AA|    -4.0|    21.0|
|2015-01-01| N005AA|   EGE| ORD|     AA|    null|    null|
|2015-01-01| N005AA|   ORD| EGE|     AA|    null|    null|
|2015-01-01| N005AA|   DFW| ORD|     AA|    null|    null|
|2015-01-01| N006AA|   DFW| ATL|     AA|    null|    null|
|2015-01-01| N006AA|   ATL| DFW|     AA|    -5.0|     1.0|
|2015-01-01| N006AA|   DFW| ATL|     AA|    -4.0|   -11.0|
|2015-01-01| N007AA|   DFW| ATL|     AA|    76.0|    86.0|
|2015-01-01| N008AA|   ATL| DFW|     AA|    -2.0|    -7.0|
|2015-01-01| N008AA|   DFW| ATL|     AA|    -5.0|   -25.0|
|2015-01-01| N009AA|   DFW| EGE|     AA|    35.0|    17.0|
|2015-01-01| N009AA|   EGE| LAX|     AA|    10.0|   -12.0|
|2015-01-01| N010AA|   DFW| SDF|     AA|    null|    null|
|2015-01-01| N010AA|   SDF| DFW|     AA|    null|    null|
+----------+-------+------+----+-------+--------+--------+
only showing top 20 rows

+----------+-------+------+----+-------+--------+--------+
|FlightDate|TailNum|Origin|Dest|Carrier|DepDelay|ArrDelay|
+----------+-------+------+----+-------+--------+--------+
|2015-01-01| N821AW|   MCO| PHL|     US|     3.0|    -4.0|
|2015-01-10| N284WN|   BWI| SAN|     WN|    41.0|    22.0|
|2015-01-10| N416WN|   MSP| MDW|     WN|    37.0|    25.0|
|2015-01-11| N662JB|   BOS| FLL|     B6|    -1.0|   -20.0|
|2015-01-13| N440UA|   ORD| DEN|     UA|    20.0|    25.0|
|2015-01-15| N588NK|   FLL| SJU|     NK|    -5.0|   -14.0|
|2015-01-16| N685BR|   GTF| SLC|     OO|    -6.0|     6.0|
|2015-01-16| N7720F|   SNA| PHX|     WN|    18.0|     2.0|
|2015-01-16| N7726A|   OAK| SNA|     WN|    -1.0|    -7.0|
|2015-01-17| N3CTAA|   SFO| DFW|     AA|    -2.0|   -18.0|
|2015-01-18| N980EV|   ATL| FSM|     EV|    -7.0|     2.0|
|2015-01-19| N944UW|   DCA| LGA|     US|    -7.0|     0.0|
|2015-01-02| N531NK|   FLL| DTW|     NK|    61.0|    50.0|
|2015-01-02| N914WN|   DAL| DEN|     WN|     0.0|   -21.0|
|2015-01-20| N649MQ|   DFW| GCK|     MQ|     2.0|     3.0|
|2015-01-20| N682MQ|   MIA| EYW|     MQ|    -5.0|   -13.0|
|2015-01-20| N465SW|   RNO| LAX|     OO|    -6.0|    -6.0|
|2015-01-20| N69834|   MCO| ORD|     UA|    -2.0|   -10.0|
|2015-01-22| N965DN|   MCI| DTW|     DL|    -3.0|   -27.0|
|2015-01-23| N488AA|   DFW| DAY|     AA|    -1.0|    18.0|
+----------+-------+------+----+-------+--------+--------+
only showing top 20 rows

Calculating Histograms

Using RDDs to calculate histogram buckets and values...


In [10]:
# Compute a histogram of departure delays
on_time_dataframe\
  .select("DepDelay")\
  .rdd\
  .flatMap(lambda x: x)\
  .histogram(10)


Out[10]:
([-82.0,
  125.0,
  332.0,
  539.0,
  746.0,
  953.0,
  1160.0,
  1367.0,
  1574.0,
  1781.0,
  1988.0],
 [11247596, 201786, 12808, 2026, 972, 422, 152, 68, 18, 4])

Visualizing Histograms

Using pyplot to visualize histograms...


In [14]:
%matplotlib inline

import numpy as np
import matplotlib.pyplot as plt

# Function to plot a histogram using pyplot
def create_hist(rdd_histogram_data):
  """Given an RDD.histogram, plot a pyplot histogram"""
  heights = np.array(rdd_histogram_data[1])
  full_bins = rdd_histogram_data[0]
  # Use the left edge of each bin as its bar's x position
  bin_left_edges = full_bins[:-1]
  widths = [abs(i - j) for i, j in zip(full_bins[:-1], full_bins[1:])]
  bar = plt.bar(bin_left_edges, heights, width=widths, color='b')
  return bar

# Compute a histogram of departure delays
departure_delay_histogram = on_time_dataframe\
  .select("DepDelay")\
  .rdd\
  .flatMap(lambda x: x)\
  .histogram([-60,-30,-15,-10,-5,0,5,10,15,30,60,90,120,180])

create_hist(departure_delay_histogram)


Out[14]:
<Container object of 13 artists>

Counting Airplanes in the US Fleet


In [12]:
# Dump the unneeded fields
tail_numbers = on_time_dataframe.rdd.map(lambda x: x.TailNum)
tail_numbers = tail_numbers.filter(lambda x: x != '')

# distinct() gets us unique tail numbers
unique_tail_numbers = tail_numbers.distinct()

# now we need a count() of unique tail numbers
airplane_count = unique_tail_numbers.count()
print("Total airplanes: {}".format(airplane_count))


Total airplanes: 4898
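
The same count can be computed without dropping down to the RDD API; a minimal DataFrame-only sketch (assuming missing tail numbers are stored as empty strings, as above):


In [ ]:
# DataFrame equivalent: count distinct, non-empty tail numbers
on_time_dataframe\
  .filter(on_time_dataframe.TailNum != '')\
  .select("TailNum")\
  .distinct()\
  .count()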

Counting the Total Flights Per Month


In [15]:
# Use SQL to look at the total flights by month across 2015
on_time_dataframe.registerTempTable("on_time_dataframe")
total_flights_by_month = spark.sql(
  """SELECT Month, Year, COUNT(*) AS total_flights
  FROM on_time_dataframe
  GROUP BY Year, Month
  ORDER BY Year, Month"""
)

# This map/asDict trick makes the rows print a little prettier. It is optional.
flights_chart_data = total_flights_by_month.rdd.map(lambda row: row.asDict())
flights_chart_data.collect()


Out[15]:
[{'Month': '1', 'Year': '2015', 'total_flights': 939936},
 {'Month': '10', 'Year': '2015', 'total_flights': 972330},
 {'Month': '11', 'Year': '2015', 'total_flights': 935944},
 {'Month': '12', 'Year': '2015', 'total_flights': 958460},
 {'Month': '2', 'Year': '2015', 'total_flights': 858382},
 {'Month': '3', 'Year': '2015', 'total_flights': 1008624},
 {'Month': '4', 'Year': '2015', 'total_flights': 970302},
 {'Month': '5', 'Year': '2015', 'total_flights': 993986},
 {'Month': '6', 'Year': '2015', 'total_flights': 1007794},
 {'Month': '7', 'Year': '2015', 'total_flights': 1041436},
 {'Month': '8', 'Year': '2015', 'total_flights': 1021072},
 {'Month': '9', 'Year': '2015', 'total_flights': 929892}]
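
The name flights_chart_data hints at charting; a minimal pyplot sketch of total flights per month, reusing the collected dicts (this assumes matplotlib.pyplot is already imported as plt from the histogram section, and that Month is a string as shown above):


In [ ]:
# Sketch: bar chart of total flights per month
chart_rows = sorted(flights_chart_data.collect(), key=lambda row: int(row['Month']))
months = [int(row['Month']) for row in chart_rows]
totals = [row['total_flights'] for row in chart_rows]
plt.bar(months, totals)
plt.xlabel("Month")
plt.ylabel("Total flights")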

Using RDDs and Map/Reduce to Prepare a Complex Record


In [16]:
# Filter down to the fields we need to identify and link to a flight
flights = on_time_dataframe.rdd.map(lambda x: 
  (x.Carrier, x.FlightDate, x.FlightNum, x.Origin, x.Dest, x.TailNum)
  )

# Group flights by tail number, sorted by date, then flight number, then origin/dest
flights_per_airplane = flights\
  .map(lambda nameTuple: (nameTuple[5], [nameTuple[0:5]]))\
  .reduceByKey(lambda a, b: a + b)\
  .map(lambda tuple:
      {
        'TailNum': tuple[0], 
        'Flights': sorted(tuple[1], key=lambda x: (x[1], x[2], x[3], x[4]))
      }
    )
flights_per_airplane.first()


Out[16]:
{'Flights': [('UA', '2015-09-05', '1132', 'DEN', 'LAX'),
  ('UA', '2015-09-05', '1132', 'DEN', 'LAX'),
  ('UA', '2015-09-05', '1169', 'IAH', 'LAX'),
  ('UA', '2015-09-05', '1169', 'IAH', 'LAX'),
  ('UA', '2015-09-05', '1804', 'LAX', 'DEN'),
  ('UA', '2015-09-05', '1804', 'LAX', 'DEN'),
  ('UA', '2015-09-06', '1977', 'LAX', 'IAH'),
  ('UA', '2015-09-06', '1977', 'LAX', 'IAH'),
  ('UA', '2015-09-08', '1423', 'IAH', 'LAX'),
  ('UA', '2015-09-08', '1423', 'IAH', 'LAX'),
  ('UA', '2015-09-08', '1693', 'LAX', 'IAH'),
  ('UA', '2015-09-08', '1693', 'LAX', 'IAH'),
  ('UA', '2015-09-09', '1614', 'IAH', 'SFO'),
  ('UA', '2015-09-09', '1614', 'IAH', 'SFO'),
  ('UA', '2015-09-09', '1937', 'SFO', 'IAH'),
  ('UA', '2015-09-09', '1937', 'SFO', 'IAH'),
  ('UA', '2015-09-10', '1169', 'IAH', 'LAX'),
  ('UA', '2015-09-10', '1169', 'IAH', 'LAX'),
  ('UA', '2015-09-10', '1631', 'DEN', 'LAX'),
  ('UA', '2015-09-10', '1631', 'DEN', 'LAX'),
  ('UA', '2015-09-10', '1653', 'LAX', 'DEN'),
  ('UA', '2015-09-10', '1653', 'LAX', 'DEN'),
  ('UA', '2015-09-10', '1977', 'LAX', 'IAH'),
  ('UA', '2015-09-10', '1977', 'LAX', 'IAH'),
  ('UA', '2015-09-11', '1181', 'SFO', 'IAH'),
  ('UA', '2015-09-11', '1181', 'SFO', 'IAH'),
  ('UA', '2015-09-11', '1614', 'IAH', 'SFO'),
  ('UA', '2015-09-11', '1614', 'IAH', 'SFO'),
  ('UA', '2015-09-12', '1715', 'IAH', 'LAX'),
  ('UA', '2015-09-12', '1715', 'IAH', 'LAX'),
  ('UA', '2015-09-12', '1977', 'LAX', 'IAH'),
  ('UA', '2015-09-12', '1977', 'LAX', 'IAH'),
  ('UA', '2015-09-21', '1126', 'IAH', 'DEN'),
  ('UA', '2015-09-21', '1126', 'IAH', 'DEN'),
  ('UA', '2015-09-22', '1905', 'DEN', 'IAH'),
  ('UA', '2015-09-22', '1905', 'DEN', 'IAH'),
  ('UA', '2015-09-24', '1126', 'IAH', 'DEN'),
  ('UA', '2015-09-24', '1126', 'IAH', 'DEN'),
  ('UA', '2015-09-25', '1051', 'DEN', 'IAH'),
  ('UA', '2015-09-25', '1051', 'DEN', 'IAH'),
  ('UA', '2015-09-29', '1126', 'IAH', 'DEN'),
  ('UA', '2015-09-29', '1126', 'IAH', 'DEN'),
  ('UA', '2015-09-30', '1051', 'DEN', 'IAH'),
  ('UA', '2015-09-30', '1051', 'DEN', 'IAH'),
  ('UA', '2015-10-08', '1169', 'IAH', 'LAX'),
  ('UA', '2015-10-08', '1169', 'IAH', 'LAX'),
  ('UA', '2015-10-19', '1169', 'IAH', 'LAX'),
  ('UA', '2015-10-19', '1169', 'IAH', 'LAX'),
  ('UA', '2015-10-23', '1405', 'LAX', 'DEN'),
  ('UA', '2015-10-23', '1405', 'LAX', 'DEN'),
  ('UA', '2015-10-23', '1728', 'DEN', 'LAX'),
  ('UA', '2015-10-23', '1728', 'DEN', 'LAX'),
  ('UA', '2015-10-29', '1963', 'LAX', 'IAH'),
  ('UA', '2015-10-29', '1963', 'LAX', 'IAH'),
  ('UA', '2015-10-31', '1030', 'IAH', 'DEN'),
  ('UA', '2015-10-31', '1030', 'IAH', 'DEN'),
  ('UA', '2015-10-31', '1877', 'DEN', 'IAH'),
  ('UA', '2015-10-31', '1877', 'DEN', 'IAH'),
  ('UA', '2015-11-03', '1169', 'IAH', 'LAX'),
  ('UA', '2015-11-03', '1169', 'IAH', 'LAX'),
  ('UA', '2015-11-12', '1963', 'LAX', 'IAH'),
  ('UA', '2015-11-12', '1963', 'LAX', 'IAH'),
  ('UA', '2015-11-14', '1865', 'DEN', 'IAH'),
  ('UA', '2015-11-14', '1865', 'DEN', 'IAH'),
  ('UA', '2015-11-14', '328', 'IAH', 'DEN'),
  ('UA', '2015-11-14', '328', 'IAH', 'DEN'),
  ('UA', '2015-11-19', '1051', 'DEN', 'IAH'),
  ('UA', '2015-11-19', '1051', 'DEN', 'IAH'),
  ('UA', '2015-11-19', '328', 'IAH', 'DEN'),
  ('UA', '2015-11-19', '328', 'IAH', 'DEN'),
  ('UA', '2015-11-22', '1169', 'IAH', 'LAX'),
  ('UA', '2015-11-22', '1169', 'IAH', 'LAX'),
  ('UA', '2015-11-26', '1963', 'LAX', 'IAH'),
  ('UA', '2015-11-26', '1963', 'LAX', 'IAH'),
  ('UA', '2015-11-28', '1169', 'IAH', 'LAX'),
  ('UA', '2015-11-28', '1169', 'IAH', 'LAX'),
  ('UA', '2015-12-02', '1963', 'LAX', 'IAH'),
  ('UA', '2015-12-02', '1963', 'LAX', 'IAH'),
  ('UA', '2015-12-04', '1169', 'IAH', 'LAX'),
  ('UA', '2015-12-04', '1169', 'IAH', 'LAX'),
  ('UA', '2015-12-16', '1963', 'LAX', 'IAH'),
  ('UA', '2015-12-16', '1963', 'LAX', 'IAH'),
  ('UA', '2015-12-18', '1169', 'IAH', 'LAX'),
  ('UA', '2015-12-18', '1169', 'IAH', 'LAX'),
  ('UA', '2015-12-22', '1963', 'LAX', 'IAH'),
  ('UA', '2015-12-22', '1963', 'LAX', 'IAH'),
  ('UA', '2015-12-24', '2011', 'DEN', 'IAH'),
  ('UA', '2015-12-24', '2011', 'DEN', 'IAH'),
  ('UA', '2015-12-24', '507', 'IAH', 'DEN'),
  ('UA', '2015-12-24', '507', 'IAH', 'DEN'),
  ('UA', '2015-12-28', '1169', 'IAH', 'LAX'),
  ('UA', '2015-12-28', '1169', 'IAH', 'LAX')],
 'TailNum': 'N27957'}

Counting Late Flights


In [ ]:
total_flights = on_time_dataframe.count()

# Flights that were late leaving...
late_departures = on_time_dataframe.filter(
  on_time_dataframe.DepDelayMinutes > 0
)
total_late_departures = late_departures.count()
print(total_late_departures)

# Flights that were late arriving...
late_arrivals = on_time_dataframe.filter(
  on_time_dataframe.ArrDelayMinutes > 0
)
total_late_arrivals = late_arrivals.count()
print(total_late_arrivals)

# Get the percentage of flights that are late, rounded to 1 decimal place
pct_late = round((total_late_arrivals / (total_flights * 1.0)) * 100, 1)

Counting Flights with Hero Captains

"Hero Captains" are those that depart late but make up time in the air and arrive on time or early.


In [ ]:
# Flights that left late but made up time to arrive on time...
on_time_heros = on_time_dataframe.filter(
  (on_time_dataframe.DepDelayMinutes > 0)
  &
  (on_time_dataframe.ArrDelayMinutes <= 0)
)
total_on_time_heros = on_time_heros.count()
print(total_on_time_heros)

Printing Our Results


In [20]:
print("Total flights:   {:,}".format(total_flights))
print("Late departures: {:,}".format(total_late_departures))
print("Late arrivals:   {:,}".format(total_late_arrivals))
print("Recoveries:      {:,}".format(total_on_time_heros))
print("Percentage Late: {}%".format(pct_late))


Total flights:   11,638,158
Late departures: 4,251,236
Late arrivals:   4,173,792
Recoveries:      1,213,804
Percentage Late: 35.9%

Computing the Average Lateness Per Flight


In [21]:
# Get the average minutes late departing and arriving
spark.sql("""
SELECT
  ROUND(AVG(DepDelay),1) AS AvgDepDelay,
  ROUND(AVG(ArrDelay),1) AS AvgArrDelay
FROM on_time_performance
"""
).show()


+-----------+-----------+
|AvgDepDelay|AvgArrDelay|
+-----------+-----------+
|        9.4|        4.4|
+-----------+-----------+

Inspecting Late Flights


In [22]:
# Why are flights late? Let's look at some delayed flights and the delay causes
late_flights = spark.sql("""
SELECT
  ArrDelayMinutes,
  WeatherDelay,
  CarrierDelay,
  NASDelay,
  SecurityDelay,
  LateAircraftDelay
FROM
  on_time_performance
WHERE
  WeatherDelay IS NOT NULL
  OR
  CarrierDelay IS NOT NULL
  OR
  NASDelay IS NOT NULL
  OR
  SecurityDelay IS NOT NULL
  OR
  LateAircraftDelay IS NOT NULL
ORDER BY
  FlightDate
""")
late_flights.sample(False, 0.01).show()


+---------------+------------+------------+--------+-------------+-----------------+
|ArrDelayMinutes|WeatherDelay|CarrierDelay|NASDelay|SecurityDelay|LateAircraftDelay|
+---------------+------------+------------+--------+-------------+-----------------+
|           25.0|         0.0|        25.0|     0.0|          0.0|              0.0|
|          147.0|         0.0|        20.0|     0.0|          0.0|            127.0|
|           17.0|         0.0|         0.0|    17.0|          0.0|              0.0|
|           18.0|         0.0|         0.0|    18.0|          0.0|              0.0|
|           47.0|         0.0|         0.0|    47.0|          0.0|              0.0|
|          235.0|         0.0|         0.0|     0.0|          0.0|            235.0|
|           23.0|         0.0|        23.0|     0.0|          0.0|              0.0|
|           23.0|         0.0|         1.0|    17.0|          0.0|              5.0|
|           28.0|         0.0|         0.0|    28.0|          0.0|              0.0|
|           51.0|         0.0|        35.0|    11.0|          0.0|              5.0|
|           35.0|         0.0|         0.0|    35.0|          0.0|              0.0|
|           38.0|         0.0|         0.0|    32.0|          0.0|              6.0|
|           40.0|        17.0|         0.0|    11.0|          0.0|             12.0|
|           20.0|         0.0|         0.0|    20.0|          0.0|              0.0|
|           20.0|         0.0|         0.0|     3.0|          0.0|             17.0|
|           69.0|         0.0|         0.0|    13.0|          0.0|             56.0|
|           31.0|         0.0|        25.0|     6.0|          0.0|              0.0|
|           56.0|         0.0|         9.0|    47.0|          0.0|              0.0|
|           87.0|         0.0|        87.0|     0.0|          0.0|              0.0|
|           53.0|         0.0|        53.0|     0.0|          0.0|              0.0|
+---------------+------------+------------+--------+-------------+-----------------+
only showing top 20 rows

Determining Why Flights Are Late


In [23]:
# Calculate the percentage contribution to delay for each source
total_delays = spark.sql("""
SELECT
  ROUND(SUM(WeatherDelay)/SUM(ArrDelayMinutes) * 100, 1) AS pct_weather_delay,
  ROUND(SUM(CarrierDelay)/SUM(ArrDelayMinutes) * 100, 1) AS pct_carrier_delay,
  ROUND(SUM(NASDelay)/SUM(ArrDelayMinutes) * 100, 1) AS pct_nas_delay,
  ROUND(SUM(SecurityDelay)/SUM(ArrDelayMinutes) * 100, 1) AS pct_security_delay,
  ROUND(SUM(LateAircraftDelay)/SUM(ArrDelayMinutes) * 100, 1) AS pct_late_aircraft_delay
FROM on_time_performance
""")
total_delays.show()


+-----------------+-----------------+-------------+------------------+-----------------------+
|pct_weather_delay|pct_carrier_delay|pct_nas_delay|pct_security_delay|pct_late_aircraft_delay|
+-----------------+-----------------+-------------+------------------+-----------------------+
|              4.5|             29.2|         20.7|               0.1|                   36.1|
+-----------------+-----------------+-------------+------------------+-----------------------+

Computing a Histogram of Weather Delayed Flights


In [26]:
# Eyeball the delay distribution to define our buckets
weather_delay_histogram = on_time_dataframe\
  .select("WeatherDelay")\
  .rdd\
  .flatMap(lambda x: x)\
  .histogram([1, 5, 10, 15, 30, 60, 120, 240, 480, 720, 24*60.0])
print(weather_delay_histogram)


([1, 5, 10, 15, 30, 60, 120, 240, 480, 720, 1440.0], [10872, 15336, 13272, 32014, 27138, 18884, 9196, 2272, 304, 144])

In [25]:
create_hist(weather_delay_histogram)


Out[25]:
<Container object of 8 artists>

Preparing a Histogram for Visualization by d3.js


In [ ]:
# Transform the data into something easily consumed by d3
def histogram_to_publishable(histogram):
  record = {'key': 1, 'data': []}
  for label, value in zip(histogram[0], histogram[1]):
    record['data'].append(
      {
        'label': label,
        'value': value
      }
    )
  return record

# Recompute the weather histogram with a filter for on-time flights
weather_delay_histogram = on_time_dataframe\
  .filter(
    (on_time_dataframe.WeatherDelay.isNotNull())
    &
    (on_time_dataframe.WeatherDelay > 0)
  )\
  .select("WeatherDelay")\
  .rdd\
  .flatMap(lambda x: x)\
  .histogram([0, 15, 30, 60, 120, 240, 480, 720, 24*60.0])
print(weather_delay_histogram)

record = histogram_to_publishable(weather_delay_histogram)
record
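
To get this record in front of d3.js, it has to be serialized as JSON somewhere the web layer can read it; a minimal sketch (the output path is hypothetical):


In [ ]:
import json

# Write the histogram record as JSON for a d3.js chart to load (hypothetical path)
with open("../data/weather_delay_histogram.json", "w") as f:
  json.dump(record, f)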

Building a Classifier Model to Predict Flight Delays

Loading Our Data


In [31]:
from pyspark.sql.types import StringType, IntegerType, FloatType, DoubleType, DateType, TimestampType
from pyspark.sql.types import StructType, StructField
from pyspark.sql.functions import udf

schema = StructType([
  StructField("ArrDelay", DoubleType(), True),     # "ArrDelay":5.0
  StructField("CRSArrTime", TimestampType(), True),    # "CRSArrTime":"2015-12-31T03:20:00.000-08:00"
  StructField("CRSDepTime", TimestampType(), True),    # "CRSDepTime":"2015-12-31T03:05:00.000-08:00"
  StructField("Carrier", StringType(), True),     # "Carrier":"WN"
  StructField("DayOfMonth", IntegerType(), True), # "DayOfMonth":31
  StructField("DayOfWeek", IntegerType(), True),  # "DayOfWeek":4
  StructField("DayOfYear", IntegerType(), True),  # "DayOfYear":365
  StructField("DepDelay", DoubleType(), True),     # "DepDelay":14.0
  StructField("Dest", StringType(), True),        # "Dest":"SAN"
  StructField("Distance", DoubleType(), True),     # "Distance":368.0
  StructField("FlightDate", DateType(), True),    # "FlightDate":"2015-12-30T16:00:00.000-08:00"
  StructField("FlightNum", StringType(), True),   # "FlightNum":"6109"
  StructField("Origin", StringType(), True),      # "Origin":"TUS"
])

features = spark.read.json(
  "../data/simple_flight_delay_features.jsonl.bz2",
  schema=schema
)
features.first()


Out[31]:
Row(ArrDelay=13.0, CRSArrTime=datetime.datetime(2015, 1, 1, 10, 10), CRSDepTime=datetime.datetime(2015, 1, 1, 7, 30), Carrier='AA', DayOfMonth=1, DayOfWeek=4, DayOfYear=1, DepDelay=14.0, Dest='DFW', Distance=569.0, FlightDate=datetime.date(2014, 12, 31), FlightNum='1024', Origin='ABQ')

Check Data for Nulls


In [32]:
#
# Check for nulls in features before using Spark ML
#
null_counts = [(column, features.where(features[column].isNull()).count()) for column in features.columns]
cols_with_nulls = filter(lambda x: x[1] > 0, null_counts)
print(list(cols_with_nulls))


[]
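
There are no nulls here, but if any column had them, the Spark ML feature transformers below would fail; a minimal sketch of two common remedies (illustrative variable names):


In [ ]:
# If nulls had shown up: drop the affected rows...
features_no_nulls = features.dropna()

# ...or fill nulls in numeric columns with a default value
features_filled = features.na.fill(0.0)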

Add a Route Column

Demonstrating the addition of a feature to our model...


In [33]:
#
# Add a Route variable to replace FlightNum
#
from pyspark.sql.functions import lit, concat

features_with_route = features.withColumn(
  'Route',
  concat(
    features.Origin,
    lit('-'),
    features.Dest
  )
)
features_with_route.select("Origin", "Dest", "Route").show(5)


+------+----+-------+
|Origin|Dest|  Route|
+------+----+-------+
|   ABQ| DFW|ABQ-DFW|
|   ABQ| DFW|ABQ-DFW|
|   ABQ| DFW|ABQ-DFW|
|   ATL| DFW|ATL-DFW|
|   ATL| DFW|ATL-DFW|
+------+----+-------+
only showing top 5 rows

Bucketizing ArrDelay into ArrDelayBucket


In [34]:
#
# Use pyspark.ml.feature.Bucketizer to bucketize ArrDelay
#
from pyspark.ml.feature import Bucketizer

splits = [-float("inf"), -15.0, 0, 30.0, float("inf")]
bucketizer = Bucketizer(
  splits=splits,
  inputCol="ArrDelay",
  outputCol="ArrDelayBucket"
)
ml_bucketized_features = bucketizer.transform(features_with_route)

# Check the buckets out
ml_bucketized_features.select("ArrDelay", "ArrDelayBucket").show()


+--------+--------------+
|ArrDelay|ArrDelayBucket|
+--------+--------------+
|    13.0|           2.0|
|    17.0|           2.0|
|    36.0|           3.0|
|   -21.0|           0.0|
|   -14.0|           1.0|
|    16.0|           2.0|
|    -7.0|           1.0|
|    13.0|           2.0|
|    25.0|           2.0|
|    58.0|           3.0|
|    14.0|           2.0|
|     1.0|           2.0|
|   -29.0|           0.0|
|   -10.0|           1.0|
|    -3.0|           1.0|
|    -8.0|           1.0|
|    -1.0|           1.0|
|   -14.0|           1.0|
|   -16.0|           0.0|
|    18.0|           2.0|
+--------+--------------+
only showing top 20 rows

Indexing Our String Fields into Numeric Fields


In [35]:
#
# Feature extraction tools from pyspark.ml.feature
#
from pyspark.ml.feature import StringIndexer, VectorAssembler

# Index categorical fields into numeric index columns for the classifier
for column in ["Carrier", "DayOfMonth", "DayOfWeek", "DayOfYear",
               "Origin", "Dest", "Route"]:
  string_indexer = StringIndexer(
    inputCol=column,
    outputCol=column + "_index"
  )
  ml_bucketized_features = string_indexer.fit(ml_bucketized_features)\
                                          .transform(ml_bucketized_features)

# Check out the indexes
ml_bucketized_features.show(6)


+--------+--------------------+--------------------+-------+----------+---------+---------+--------+----+--------+----------+---------+------+-------+--------------+-------------+----------------+---------------+---------------+------------+----------+-----------+
|ArrDelay|          CRSArrTime|          CRSDepTime|Carrier|DayOfMonth|DayOfWeek|DayOfYear|DepDelay|Dest|Distance|FlightDate|FlightNum|Origin|  Route|ArrDelayBucket|Carrier_index|DayOfMonth_index|DayOfWeek_index|DayOfYear_index|Origin_index|Dest_index|Route_index|
+--------+--------------------+--------------------+-------+----------+---------+---------+--------+----+--------+----------+---------+------+-------+--------------+-------------+----------------+---------------+---------------+------------+----------+-----------+
|    13.0|2015-01-01 10:10:...|2015-01-01 07:30:...|     AA|         1|        4|        1|    14.0| DFW|   569.0|2014-12-31|     1024|   ABQ|ABQ-DFW|           2.0|          2.0|            25.0|            0.0|          320.0|        53.0|       2.0|      938.0|
|    17.0|2015-01-01 02:15:...|2014-12-31 23:25:...|     AA|         1|        4|        1|    14.0| DFW|   569.0|2014-12-31|     1184|   ABQ|ABQ-DFW|           2.0|          2.0|            25.0|            0.0|          320.0|        53.0|       2.0|      938.0|
|    36.0|2015-01-01 03:45:...|2015-01-01 01:00:...|     AA|         1|        4|        1|    -2.0| DFW|   569.0|2014-12-31|      336|   ABQ|ABQ-DFW|           3.0|          2.0|            25.0|            0.0|          320.0|        53.0|       2.0|      938.0|
|   -21.0|2015-01-01 11:30:...|2015-01-01 09:55:...|     AA|         1|        4|        1|    -1.0| DFW|   731.0|2014-12-31|      125|   ATL|ATL-DFW|           0.0|          2.0|            25.0|            0.0|          320.0|         0.0|       2.0|       37.0|
|   -14.0|2015-01-01 02:25:...|2015-01-01 00:55:...|     AA|         1|        4|        1|    -4.0| DFW|   731.0|2014-12-31|     1455|   ATL|ATL-DFW|           1.0|          2.0|            25.0|            0.0|          320.0|         0.0|       2.0|       37.0|
|    16.0|2015-01-01 07:15:...|2015-01-01 05:45:...|     AA|         1|        4|        1|    15.0| DFW|   731.0|2014-12-31|     1473|   ATL|ATL-DFW|           2.0|          2.0|            25.0|            0.0|          320.0|         0.0|       2.0|       37.0|
+--------+--------------------+--------------------+-------+----------+---------+---------+--------+----+--------+----------+---------+------+-------+--------------+-------------+----------------+---------------+---------------+------------+----------+-----------+
only showing top 6 rows

Combining Numeric Fields into a Single Vector


In [36]:
# Handle continuous, numeric fields by combining them into one feature vector
numeric_columns = ["DepDelay", "Distance"]
index_columns = ["Carrier_index", "DayOfMonth_index",
                 "DayOfWeek_index", "DayOfYear_index",
                 "Origin_index", "Dest_index", "Route_index"]
vector_assembler = VectorAssembler(
  inputCols=numeric_columns + index_columns,
  outputCol="Features_vec"
)
final_vectorized_features = vector_assembler.transform(ml_bucketized_features)

# Drop the index columns
for column in index_columns:
  final_vectorized_features = final_vectorized_features.drop(column)

# Check out the features
final_vectorized_features.show()


+--------+--------------------+--------------------+-------+----------+---------+---------+--------+----+--------+----------+---------+------+-------+--------------+--------------------+
|ArrDelay|          CRSArrTime|          CRSDepTime|Carrier|DayOfMonth|DayOfWeek|DayOfYear|DepDelay|Dest|Distance|FlightDate|FlightNum|Origin|  Route|ArrDelayBucket|        Features_vec|
+--------+--------------------+--------------------+-------+----------+---------+---------+--------+----+--------+----------+---------+------+-------+--------------+--------------------+
|    13.0|2015-01-01 10:10:...|2015-01-01 07:30:...|     AA|         1|        4|        1|    14.0| DFW|   569.0|2014-12-31|     1024|   ABQ|ABQ-DFW|           2.0|[14.0,569.0,2.0,2...|
|    17.0|2015-01-01 02:15:...|2014-12-31 23:25:...|     AA|         1|        4|        1|    14.0| DFW|   569.0|2014-12-31|     1184|   ABQ|ABQ-DFW|           2.0|[14.0,569.0,2.0,2...|
|    36.0|2015-01-01 03:45:...|2015-01-01 01:00:...|     AA|         1|        4|        1|    -2.0| DFW|   569.0|2014-12-31|      336|   ABQ|ABQ-DFW|           3.0|[-2.0,569.0,2.0,2...|
|   -21.0|2015-01-01 11:30:...|2015-01-01 09:55:...|     AA|         1|        4|        1|    -1.0| DFW|   731.0|2014-12-31|      125|   ATL|ATL-DFW|           0.0|[-1.0,731.0,2.0,2...|
|   -14.0|2015-01-01 02:25:...|2015-01-01 00:55:...|     AA|         1|        4|        1|    -4.0| DFW|   731.0|2014-12-31|     1455|   ATL|ATL-DFW|           1.0|[-4.0,731.0,2.0,2...|
|    16.0|2015-01-01 07:15:...|2015-01-01 05:45:...|     AA|         1|        4|        1|    15.0| DFW|   731.0|2014-12-31|     1473|   ATL|ATL-DFW|           2.0|[15.0,731.0,2.0,2...|
|    -7.0|2015-01-01 04:15:...|2015-01-01 02:45:...|     AA|         1|        4|        1|    -2.0| DFW|   731.0|2014-12-31|     1513|   ATL|ATL-DFW|           1.0|[-2.0,731.0,2.0,2...|
|    13.0|2015-01-01 08:50:...|2015-01-01 07:25:...|     AA|         1|        4|        1|     9.0| DFW|   731.0|2014-12-31|      194|   ATL|ATL-DFW|           2.0|[9.0,731.0,2.0,25...|
|    25.0|2015-01-01 12:30:...|2015-01-01 11:00:...|     AA|         1|        4|        1|    -2.0| DFW|   731.0|2014-12-31|      232|   ATL|ATL-DFW|           2.0|[-2.0,731.0,2.0,2...|
|    58.0|2015-01-01 13:40:...|2015-01-01 12:15:...|     AA|         1|        4|        1|    14.0| DFW|   731.0|2014-12-31|      276|   ATL|ATL-DFW|           3.0|[14.0,731.0,2.0,2...|
|    14.0|2015-01-01 05:25:...|2015-01-01 03:55:...|     AA|         1|        4|        1|    15.0| DFW|   731.0|2014-12-31|      314|   ATL|ATL-DFW|           2.0|[15.0,731.0,2.0,2...|
|     1.0|2015-01-01 10:05:...|2015-01-01 08:40:...|     AA|         1|        4|        1|    -5.0| DFW|   731.0|2014-12-31|      356|   ATL|ATL-DFW|           2.0|[-5.0,731.0,2.0,2...|
|   -29.0|2015-01-01 02:12:...|2015-01-01 00:15:...|     AA|         1|        4|        1|    -9.0| MIA|   594.0|2014-12-31|     1652|   ATL|ATL-MIA|           0.0|[-9.0,594.0,2.0,2...|
|   -10.0|2015-01-01 00:52:...|2014-12-31 23:00:...|     AA|         1|        4|        1|    -4.0| MIA|   594.0|2014-12-31|       17|   ATL|ATL-MIA|           1.0|[-4.0,594.0,2.0,2...|
|    -3.0|2015-01-01 15:02:...|2015-01-01 13:10:...|     AA|         1|        4|        1|    -7.0| MIA|   594.0|2014-12-31|      349|   ATL|ATL-MIA|           1.0|[-7.0,594.0,2.0,2...|
|    -8.0|2015-01-01 06:35:...|2015-01-01 05:30:...|     AA|         1|        4|        1|    -2.0| DFW|   190.0|2014-12-31|     1023|   AUS|AUS-DFW|           1.0|[-2.0,190.0,2.0,2...|
|    -1.0|2014-12-31 22:50:...|2014-12-31 21:50:...|     AA|         1|        4|        1|    -2.0| DFW|   190.0|2014-12-31|     1178|   AUS|AUS-DFW|           1.0|[-2.0,190.0,2.0,2...|
|   -14.0|2015-01-01 01:40:...|2015-01-01 00:30:...|     AA|         1|        4|        1|    -6.0| DFW|   190.0|2014-12-31|     1296|   AUS|AUS-DFW|           1.0|[-6.0,190.0,2.0,2...|
|   -16.0|2015-01-01 02:15:...|2015-01-01 01:05:...|     AA|         1|        4|        1|    -4.0| DFW|   190.0|2014-12-31|     1356|   AUS|AUS-DFW|           0.0|[-4.0,190.0,2.0,2...|
|    18.0|2015-01-01 08:55:...|2015-01-01 07:55:...|     AA|         1|        4|        1|     3.0| DFW|   190.0|2014-12-31|     1365|   AUS|AUS-DFW|           2.0|[3.0,190.0,2.0,25...|
+--------+--------------------+--------------------+-------+----------+---------+---------+--------+----+--------+----------+---------+------+-------+--------------+--------------------+
only showing top 20 rows

Training Our Model in an Experimental Setup


In [37]:
#
# Cross validate, train and evaluate classifier
#

# Test/train split
training_data, test_data = final_vectorized_features.randomSplit([0.7, 0.3])

# Instantiate and fit random forest classifier
from pyspark.ml.classification import RandomForestClassifier
rfc = RandomForestClassifier(
  featuresCol="Features_vec",
  labelCol="ArrDelayBucket",
  maxBins=4657,
  maxMemoryInMB=1024
)
model = rfc.fit(training_data)

# Evaluate model using test data
predictions = model.transform(test_data)

from pyspark.ml.evaluation import MulticlassClassificationEvaluator
evaluator = MulticlassClassificationEvaluator(labelCol="ArrDelayBucket", metricName="accuracy")
accuracy = evaluator.evaluate(predictions)
print("Accuracy = {}".format(accuracy))

# Check a sample
predictions.sample(False, 0.001, 18).orderBy("CRSDepTime").show(6)


Accuracy = 0.5965632172973679
+--------+--------------------+--------------------+-------+----------+---------+---------+--------+----+--------+----------+---------+------+-------+--------------+--------------------+--------------------+--------------------+----------+
|ArrDelay|          CRSArrTime|          CRSDepTime|Carrier|DayOfMonth|DayOfWeek|DayOfYear|DepDelay|Dest|Distance|FlightDate|FlightNum|Origin|  Route|ArrDelayBucket|        Features_vec|       rawPrediction|         probability|prediction|
+--------+--------------------+--------------------+-------+----------+---------+---------+--------+----+--------+----------+---------+------+-------+--------------+--------------------+--------------------+--------------------+----------+
|   -11.0|2014-12-31 22:13:...|2014-12-31 21:10:...|     AA|         1|        4|        1|    -3.0| MIA|   192.0|2014-12-31|     1323|   MCO|MCO-MIA|           1.0|[-3.0,192.0,2.0,2...|[3.15496188966384...|[0.15774809448319...|       1.0|
|    -3.0|2015-01-01 03:40:...|2015-01-01 02:00:...|     AS|         1|        4|        1|    -9.0| OME|   539.0|2014-12-31|      151|   ANC|ANC-OME|           1.0|[-9.0,539.0,9.0,2...|[4.02192166004895...|[0.20109608300244...|       1.0|
|   -10.0|2015-01-01 13:10:...|2015-01-01 05:25:...|     VX|         1|        4|        1|     0.0| FLL|  2342.0|2014-12-31|      330|   LAX|LAX-FLL|           1.0|[0.0,2342.0,13.0,...|[4.40711487879725...|[0.22035574393986...|       1.0|
|     3.0|2015-01-01 06:25:...|2015-01-01 05:25:...|     AA|         1|        4|        1|    -1.0| TUL|   237.0|2014-12-31|     1563|   DFW|DFW-TUL|           2.0|[-1.0,237.0,2.0,2...|[2.53084153459117...|[0.12654207672955...|       1.0|
|   -14.0|2015-01-01 07:52:...|2015-01-01 06:38:...|     OO|         1|        4|        1|    -9.0| BUR|   326.0|2014-12-31|     6400|   SFO|SFO-BUR|           1.0|[-9.0,326.0,3.0,2...|[3.22466860417944...|[0.16123343020897...|       1.0|
|     3.0|2015-01-01 15:19:...|2015-01-01 10:45:...|     B6|         1|        4|        1|    -2.0| BOS|  1698.0|2014-12-31|     1038|   AUS|AUS-BOS|           2.0|[-2.0,1698.0,7.0,...|[4.78266368992140...|[0.23913318449607...|       1.0|
+--------+--------------------+--------------------+-------+----------+---------+---------+--------+----+--------+----------+---------+------+-------+--------------+--------------------+--------------------+--------------------+----------+
only showing top 6 rows
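
Outside an experimental setup, the fitted model would normally be persisted for reuse; a minimal sketch of saving and reloading it (the path is hypothetical):


In [ ]:
from pyspark.ml.classification import RandomForestClassificationModel

# Persist the fitted model to disk and load it back (hypothetical path)
model.write().overwrite().save("../models/arr_delay_rfc.model")
reloaded_model = RandomForestClassificationModel.load("../models/arr_delay_rfc.model")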

